#include "SubElementSamplingDriver.h"

#include "VecBasis.h"
#include "BasisOps.h" 
#include "FileNameInfo.h"
#include "BasisFileStream.h"
#include "VecBasisFile.h"
#include "SimpleBuffer.h"
#include "RenumberingUtils.h"
#include "MeshDesc.h"

#include <Driver.d/GeoSource.h>
#include <Driver.d/Domain.h>
#include <Math.d/Vector.h>
#include <Timers.d/StaticTimers.h>
#include <Utils.d/Connectivity.h>
#include <Element.d/Element.h>

#include <cstddef>
#include <algorithm>
#include <numeric>
#include <iterator>
#include <set>
#include <map>
#include <vector>
#include <memory>
#include <fstream>
#include <string>
#include <utility>

#include <cassert>
#include <iostream>

extern GeoSource *geoSource;


namespace Rom {

SubElementSamplingDriver::SubElementSamplingDriver(Domain *d) :
  ElementSamplingDriver<std::vector<double>,size_t>(d)
{}

void
SubElementSamplingDriver::preProcess()
{
  domain_->makeAllDOFs();
  StaticTimers dummyTimes;
  GenFullSquareMatrix<double> *dummyGeomKelArray = NULL;
  const bool buildMelArray = true;
  domain_->computeGeometricPreStress(corotators_, geomState_, kelArray_, &dummyTimes, dummyGeomKelArray, melArray_, buildMelArray);
  if(domain_->nDirichlet() > 0) {
    geomState_->updatePrescribedDisplacement(domain_->getDBC(), domain_->nDirichlet(), domain_->getNodes());
  }
}

void
SubElementSamplingDriver::getGlobalWeights(Vector &solution, std::vector<double> &lweights, std::vector<int> &lelemids, bool verboseFlag)
{
  const FileNameInfo fileInfo;
  std::set<int> sampleElemRanks;
  {
    for (int iElem = 0; iElem != elementCount(); ++iElem) {
      if (solution[iElem] > 0.0) {
        sampleElemRanks.insert(sampleElemRanks.end(), iElem);
      }
    }
  }

  std::vector<int> packedToInput(elementCount());
  Elemset &inputElemSet = *(geoSource->getElemSet());
  for (int iElem = 0, iElemEnd = inputElemSet.size(); iElem != iElemEnd; ++iElem) {
    Element *elem = inputElemSet[iElem];
    if (elem) {
      const int iPackElem = domain_->glToPackElem(iElem);
      if(iPackElem >= 0) {
        assert(iPackElem < packedToInput.size());
        packedToInput[iPackElem] = iElem;
      }
    }
  }

  std::vector<int> sampleElemIds;
  sampleElemIds.reserve(sampleElemRanks.size());
  std::map<int, double> weights;
  for (std::set<int>::const_iterator it = sampleElemRanks.begin(), it_end = sampleElemRanks.end(); it != it_end; ++it) {
    const int elemRank = packedToInput[*it];
    weights.insert(std::make_pair(elemRank, solution[*it]));
    sampleElemIds.push_back(elemRank);
  }

  for(int i = 0; i < solution.size(); i++) {
    lelemids.push_back(packedToInput[i]);
    lweights.push_back(solution[i]);
  }
}


} // end namespace Rom

Rom::DriverInterface *subElementSamplingDriverNew(Domain *d) {
  return new Rom::SubElementSamplingDriver(d);
}
